#!/usr/bin/env python
# -*- coding: utf-8 -*-
from email import iterators
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, TensorDataset

class IIDBatchSampler:
    """Batch sampler that includes each example independently at random.

    On each of ``iterations`` rounds, every index in ``range(len(dataset))``
    is selected with probability ``minibatch_size / len(dataset)``. The batch
    size is therefore random, with mean ``minibatch_size``. Rounds that select
    no index are skipped, so fewer than ``iterations`` batches may be yielded.
    Each batch is a sorted list of int indices, suitable as a DataLoader
    ``batch_sampler``.
    """

    def __init__(self, dataset, minibatch_size, iterations):
        self.length = len(dataset)
        self.minibatch_size = minibatch_size
        self.iterations = iterations

    def __iter__(self):
        # An empty dataset has nothing to sample (and the rate below would
        # divide by zero).
        if self.length == 0:
            return
        sample_rate = self.minibatch_size / self.length
        for _ in range(self.iterations):
            mask = torch.rand(self.length) < sample_rate
            indices = torch.nonzero(mask, as_tuple=True)[0]
            if indices.numel() > 0:
                yield indices.tolist()

    def __len__(self):
        # Upper bound: empty rounds are skipped in __iter__.
        return self.iterations

class DatasetSplit(Dataset):
    """Dataset view exposing only the samples of ``dataset`` at ``idxs``.

    Each item is returned as an ``(image, label)`` pair of tensors that do
    not share storage with the underlying dataset.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Coerce indices to plain Python ints so indexing works for
        # non-int inputs (e.g. floats).
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    @staticmethod
    def _to_tensor(value):
        """Return a detached tensor copy of ``value``.

        ``torch.tensor(t)`` on an existing tensor emits a UserWarning; PyTorch
        documents ``t.clone().detach()`` as its exact equivalent.
        """
        if isinstance(value, torch.Tensor):
            return value.clone().detach()
        return torch.tensor(value)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return self._to_tensor(image), self._to_tensor(label)


class LocalUpdate(object):
    """Client-side trainer that runs local updates on one user's data shard.

    The user's samples (``dataset`` restricted to ``idxs``) are wrapped in a
    shuffled DataLoader. ``update_weights`` then trains the model either for
    a single minibatch step (modes ``'STO-SIGNSGD'`` / ``'SIGNSGD'``) or for
    ``args.local_ep`` full passes over the shard.
    """

    def __init__(self, args, dataset, idxs, logger):
        self.args = args
        self.logger = logger
        self.trainloader = self.train_loader(
            dataset, list(idxs))
        # Persistent iterator: the single-step modes continue through the
        # shard across successive update_weights calls.
        self.iterator = iter(self.trainloader)
        self.device = 'cuda' if args.gpu else 'cpu'
        # Default criterion set to NLL loss function (expects the model to
        # output log-probabilities).
        self.criterion = nn.NLLLoss().to(self.device)

    def train_loader(self, dataset, idxs):
        """Return a shuffled training DataLoader over ``dataset[idxs]``.

        Batches have size ``args.local_bs``.
        """
        trainloader = DataLoader(DatasetSplit(dataset, idxs),
                                 batch_size=self.args.local_bs, shuffle=True)
        return trainloader

    def _make_optimizer(self, model, lr):
        """Build the optimizer selected by ``args.optimizer`` for ``model``.

        Raises:
            ValueError: if ``args.optimizer`` is neither 'sgd' nor 'adam'.
        """
        if self.args.optimizer == 'sgd':
            return torch.optim.SGD(model.parameters(), lr=lr,
                                   momentum=self.args.momentum)
        if self.args.optimizer == 'adam':
            return torch.optim.Adam(model.parameters(), lr=lr,
                                    weight_decay=1e-4)
        raise ValueError(
            "Unsupported optimizer: {!r}".format(self.args.optimizer))

    def _next_batch(self):
        """Return the next (images, labels) batch, restarting when exhausted."""
        try:
            return next(self.iterator)
        except StopIteration:
            self.iterator = iter(self.trainloader)
            return next(self.iterator)

    def _train_step(self, model, optimizer, images, labels):
        """Run one forward/backward pass and optimizer step on a batch."""
        images, labels = images.to(self.device), labels.to(self.device)
        model.zero_grad()
        log_probs = model(images)
        loss = self.criterion(log_probs, labels)
        loss.backward()
        optimizer.step()

    def update_weights(self, model, global_round, lr_epoch):
        """Train ``model`` locally and return it with the decayed learning rate.

        Args:
            model: the model to train in place.
            global_round: current global round (used only for verbose logs).
            lr_epoch: learning rate for this call's optimizer.

        Returns:
            ``(model, lr)`` where ``lr`` is the learning rate after one
            ExponentialLR step, i.e. ``lr_epoch * args.gamma``.

        Raises:
            ValueError: if ``args.optimizer`` is not 'sgd' or 'adam'.
        """
        # Set mode to train model
        model.train()
        # A fresh optimizer is created on every call, so optimizer state
        # (e.g. momentum buffers) does not persist across calls.
        optimizer = self._make_optimizer(model, lr_epoch)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=self.args.gamma)

        if self.args.mode == 'STO-SIGNSGD' or self.args.mode == 'SIGNSGD':
            # Sign-based modes: exactly one minibatch step per call.
            images, labels = self._next_batch()
            self._train_step(model, optimizer, images, labels)
        else:
            # `epoch` (not `iter`) so the builtin iter() stays usable in
            # this method.
            for epoch in range(self.args.local_ep):
                for batch_idx, (images, labels) in enumerate(self.trainloader):
                    self._train_step(model, optimizer, images, labels)
                    if self.args.verbose and (batch_idx % 10 == 0):
                        print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]'.format(
                            global_round, epoch, batch_idx * len(images),
                            len(self.trainloader.dataset),
                            100. * batch_idx / len(self.trainloader)))
        scheduler.step()
        # get_last_lr() is the current rate; get_lr() called outside step()
        # is deprecated and applies gamma a second time.
        lr = scheduler.get_last_lr()[0]
        return model, lr

    def inference(self, model):
        """Return the accuracy of ``model`` on this client's training loader.

        Leaves ``model`` in eval mode. Returns 0.0 if the loader is empty.
        """
        model.eval()
        total, correct = 0.0, 0.0

        # No gradients are needed for evaluation; skip graph construction.
        with torch.no_grad():
            for images, labels in self.trainloader:
                images, labels = images.to(self.device), labels.to(self.device)

                # Inference
                outputs = model(images)

                # Prediction
                _, pred_labels = torch.max(outputs, 1)
                pred_labels = pred_labels.view(-1)
                correct += torch.sum(torch.eq(pred_labels, labels)).item()
                total += len(labels)

        if total == 0:
            return 0.0
        return correct / total


def test_inference(args, model, test_dataset):
    """Return the classification accuracy of ``model`` on ``test_dataset``.

    Batches of size ``args.local_bs`` are evaluated in order on 'cuda' when
    ``args.gpu`` is truthy, otherwise on 'cpu'. Leaves ``model`` in eval mode.
    Returns 0.0 if ``test_dataset`` is empty.
    """

    model.eval()
    total, correct = 0.0, 0.0

    device = 'cuda' if args.gpu else 'cpu'
    testloader = DataLoader(test_dataset, batch_size=args.local_bs,
                            shuffle=False)

    # No gradients are needed for evaluation; skip graph construction.
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)

            # Inference
            outputs = model(images)

            # Prediction
            _, pred_labels = torch.max(outputs, 1)
            pred_labels = pred_labels.view(-1)
            correct += torch.sum(torch.eq(pred_labels, labels)).item()
            total += len(labels)

    if total == 0:
        return 0.0
    return correct / total


# Despite its `test_` prefix this is a library helper, not a pytest test;
# stop pytest from collecting it (and failing on missing fixtures) when the
# module's names are imported into a test file.
test_inference.__test__ = False


